Deep Learning : Principles and Practices - CSE1016 - L33 + L34
Name: Nikhil V
Registration No: CH.EN.U4AIE22038
Lab 7: Transfer Learning and Fine-Tuning on the PlantVillage Dataset
In [12]:
!pip install split-folders
Requirement already satisfied: split-folders in /opt/conda/lib/python3.10/site-packages (0.5.1)
In [13]:
!pip install tensorflow
Requirement already satisfied: tensorflow in /opt/conda/lib/python3.10/site-packages (2.16.1)
Requirement already satisfied: keras>=3.0.0 in /opt/conda/lib/python3.10/site-packages (from tensorflow) (3.3.3)
Importing the required modules
In [14]:
# Modules used for data handling and visualisation
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import seaborn as sns
import random as r
sns.set_style("whitegrid")
# Modules used for suppressing warnings
import warnings
warnings.filterwarnings('ignore')
# Modules used for dataset split
import splitfolders
import os
# Modules used for model training and transfer learning
import tensorflow as tf
from tensorflow.keras.layers import Dense,Flatten
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.applications.resnet50 import ResNet50
from tensorflow.keras import Model
In [15]:
# Centering all the output images in the notebook.
from IPython.display import HTML as Center
Center(""" <style>
.output_png {
display: table-cell;
text-align: center;
vertical-align: middle;
}
</style> """)
Out[15]:
Dataset Exploration
In [16]:
class Dataset:
    """Helper for inspecting an image dataset stored as one sub-folder per class."""

    def __init__(self, dataset_path : str):
        self.PARENT = dataset_path
        self.class_distribution = dict()

    def __compute_class_distributions(self):
        # Count the image files inside each class sub-folder.
        for dirname in os.listdir(self.PARENT):
            self.class_distribution[dirname] = len(os.listdir(os.path.join(self.PARENT, dirname)))

    def class_distributions(self):
        # Plot the number of images per class as a bar chart.
        self.__compute_class_distributions()
        plt.figure(figsize=(10,10))
        plt.bar(self.class_distribution.keys(),
                self.class_distribution.values(),
                color=["crimson","red","orange","yellow"])
        plt.xticks(rotation=90)
        plt.title("Class Distribution of PlantVillage dataset")
        plt.xlabel("Class Label")
        plt.ylabel("Frequency of class")
        plt.show()

    def show_class_samples(self):
        # Show one random image per class on a 5 x 3 grid (assumes at most 15 classes).
        rows = 5
        columns = 3
        c = 0
        fig, axs = plt.subplots(rows, columns, figsize=(15,15))
        for dirname in os.listdir(self.PARENT):
            img_path = r.choice(os.listdir(os.path.join(self.PARENT, dirname)))
            image = mpimg.imread(os.path.join(self.PARENT, dirname, img_path))
            axs[c//columns, c%columns].imshow(image)
            axs[c//columns, c%columns].set_title(dirname)
            c += 1
        fig.suptitle("Image Samples of Plant Village dataset")
        plt.subplots_adjust(bottom=0.1, top=0.9, hspace=0.5)
        plt.show()
Loading the dataset
In [17]:
plant_village = Dataset("/kaggle/input/plantdisease/PlantVillage")
Class Distribution
In [18]:
plant_village.class_distributions()
Sample Images
In [19]:
plant_village.show_class_samples()
Train, Test, Validation Split
In [20]:
class DataSplit:
    def __init__(self, dataset_path : str, destination_path : str, train : float, test : float, val : float) -> None:
        self.PARENT = dataset_path
        self.TRAIN = train
        self.TEST = test
        self.VAL = val
        self.destination_path = destination_path
        self.train_gen = None
        self.test_gen = None
        self.val_gen = None
        # splitfolders writes train/, val/ and test/ sub-folders under destination_path.
        self.TRAIN_DIR = os.path.join(destination_path, "train")
        self.TEST_DIR = os.path.join(destination_path, "test")
        self.VAL_DIR = os.path.join(destination_path, "val")

    def test_train_validation_split(self):
        # Compare floats with a tolerance instead of exact equality.
        assert abs(self.TRAIN + self.TEST + self.VAL - 1.0) < 1e-9
        # splitfolders expects the ratio in (train, val, test) order.
        splitfolders.ratio(input = self.PARENT,
                           output = self.destination_path,
                           seed = 1337, ratio = (self.TRAIN, self.VAL, self.TEST),
                           group_prefix = None,
                           move = False)

    def create_generators(self):
        # No augmentation: each generator only applies ResNet50's own preprocessing.
        self.train_gen = tf.keras.preprocessing.image.ImageDataGenerator(
            preprocessing_function=tf.keras.applications.resnet50.preprocess_input
        )
        self.test_gen = tf.keras.preprocessing.image.ImageDataGenerator(
            preprocessing_function=tf.keras.applications.resnet50.preprocess_input
        )
        self.val_gen = tf.keras.preprocessing.image.ImageDataGenerator(
            preprocessing_function=tf.keras.applications.resnet50.preprocess_input
        )

    def get_images(self):
        # Images are resized to 75x75 to match the model's input_shape.
        train_images = self.train_gen.flow_from_directory(
            directory=self.TRAIN_DIR,
            target_size=(75, 75),
            color_mode='rgb',
            class_mode='categorical',
            batch_size=32,
            shuffle=True,
            seed=42
        )
        val_images = self.val_gen.flow_from_directory(
            directory=self.VAL_DIR,
            target_size=(75, 75),
            color_mode='rgb',
            class_mode='categorical',
            batch_size=32,
            shuffle=True,
            seed=42
        )
        # shuffle=False keeps test_images.labels in the same order as the predictions.
        test_images = self.test_gen.flow_from_directory(
            directory=self.TEST_DIR,
            target_size=(75, 75),
            color_mode='rgb',
            class_mode='categorical',
            batch_size=32,
            shuffle=False,
            seed=42
        )
        return train_images, val_images, test_images
In [23]:
ds = DataSplit("/kaggle/input/plantdisease/PlantVillage","dataset",0.8,0.1, 0.1)
In [24]:
ds.test_train_validation_split()
Copying files: 20639 files [03:39, 94.17 files/s]
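splitfolders.ratio splits every class folder separately. The plain-Python sketch below illustrates that idea under an assumption: each class's file list is shuffled with a fixed seed, the train and validation sizes are rounded down, and whatever remains goes to the test split. `split_class` is a hypothetical helper, not part of splitfolders, and its shuffle is not guaranteed to reproduce the library's exact file assignment.

```python
import random


def split_class(files, ratio=(0.8, 0.1, 0.1), seed=1337):
    """Hypothetical per-class ratio split returning (train, val, test) lists."""
    files = list(files)
    random.Random(seed).shuffle(files)       # deterministic shuffle for a given seed
    n_train = int(ratio[0] * len(files))     # rounded down
    n_val = int(ratio[1] * len(files))       # rounded down
    train = files[:n_train]
    val = files[n_train:n_train + n_val]
    test = files[n_train + n_val:]           # the rounding remainder lands here
    return train, val, test


# A class with 1,000 images splits exactly 800/100/100 ...
print([len(s) for s in split_class(range(1000))])  # → [800, 100, 100]
# ... while a class with 997 images gives 797/99/101.
print([len(s) for s in split_class(range(997))])   # → [797, 99, 101]
```

If the library rounds this way, the test split collects the remainders of every class, which would be consistent with the generators later finding 2,076 test images but only 2,058 validation images.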
Train Data Insights
In [25]:
train = Dataset("dataset/train/")
In [26]:
train.class_distributions()
In [27]:
train.show_class_samples()
Test Data Insights
In [28]:
test = Dataset("dataset/test/")
In [29]:
test.class_distributions()
In [30]:
test.show_class_samples()
Validation Data Insights
In [31]:
val = Dataset("dataset/val/")
In [32]:
val.class_distributions()
In [33]:
val.show_class_samples()
Creating the data generators
In [34]:
ds.create_generators()
In [35]:
train, val, test = ds.get_images()
Found 16504 images belonging to 15 classes.
Found 2058 images belonging to 15 classes.
Found 2076 images belonging to 15 classes.
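All three generators apply tf.keras.applications.resnet50.preprocess_input. For Keras' ResNet50 this is "caffe"-style preprocessing: channels are reordered from RGB to BGR and each channel is zero-centred with the ImageNet means, with no scaling to [0, 1]. The NumPy sketch below re-implements that step for illustration only; `caffe_preprocess` is a hypothetical name, and real code should keep calling the library function.

```python
import numpy as np

# ImageNet channel means in BGR order, as used by Keras' "caffe" preprocessing mode.
IMAGENET_BGR_MEAN = np.array([103.939, 116.779, 123.68])


def caffe_preprocess(rgb_batch):
    """Hypothetical re-implementation: RGB batch (N, H, W, 3) -> BGR, mean-centred."""
    x = np.asarray(rgb_batch, dtype=np.float64)
    bgr = x[..., ::-1]                 # reverse the channel axis: RGB -> BGR
    return bgr - IMAGENET_BGR_MEAN     # per-channel subtraction via broadcasting


# One 75x75 image whose every pixel equals the ImageNet mean colour (RGB order).
mean_rgb = IMAGENET_BGR_MEAN[::-1]
batch = np.broadcast_to(mean_rgb, (1, 75, 75, 3))
out = caffe_preprocess(batch)
print(out.shape, float(np.abs(out).max()))  # → (1, 75, 75, 3) 0.0
```

Because no rescaling happens, inputs to the network stay roughly in the range -124 to +152, which is what the pretrained ImageNet weights expect.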
Transfer Learning
In [36]:
class TransferLearning:
    def __init__(self, train, val) -> None:
        self.train = train
        self.val = val
        self.model = None
        self.history = None

    def load_model(self):
        # ResNet50 convolutional base with ImageNet weights, without its 1000-class head.
        self.model = ResNet50(weights = 'imagenet',
                              include_top = False,
                              input_shape = (75,75,3))

    def mark_layers_non_trainable(self):
        # Freeze the pretrained base. Call this before compile_model(): afterwards
        # self.model also contains the new head, which would then be frozen too.
        for layer in self.model.layers:
            layer.trainable = False

    def add_final_layer(self):
        # New head: flatten the feature map, one hidden layer, 15-way softmax.
        self.x = Flatten()(self.model.output)
        self.x = Dense(1000, activation='relu')(self.x)
        self.predictions = Dense(15, activation = 'softmax')(self.x)

    def compile_model(self):
        # Wrap the base and the new head into one Model, then compile it.
        self.model = Model(inputs = self.model.input, outputs = self.predictions)
        self.model.compile(optimizer='adam', loss="categorical_crossentropy", metrics=['accuracy'])

    def train_model(self):
        # Train on the generators passed to the constructor; they already yield
        # batches of 32, so no batch_size argument is passed to fit().
        self.history = self.model.fit(self.train,
                                      epochs=10,
                                      validation_data=self.val)

    def plot_history(self):
        fig, axs = plt.subplots(2, 1, figsize=(15,15))
        axs[0].plot(self.history.history['loss'])
        axs[0].plot(self.history.history['val_loss'])
        axs[0].title.set_text('Training Loss vs Validation Loss')
        axs[0].set_xlabel('Epochs')
        axs[0].set_ylabel('Loss')
        axs[0].legend(['Train','Val'])
        axs[1].plot(self.history.history['accuracy'])
        axs[1].plot(self.history.history['val_accuracy'])
        axs[1].title.set_text('Training Accuracy vs Validation Accuracy')
        axs[1].set_xlabel('Epochs')
        axs[1].set_ylabel('Accuracy')
        axs[1].legend(['Train', 'Val'])
        plt.show()
Transfer Learning using ResNet50
In [37]:
tl = TransferLearning(train=train, val=val)
Loading ResNet50 from Keras Applications
In [38]:
tl.load_model()
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5
94765736/94765736 ━━━━━━━━━━━━━━━━━━━━ 3s 0us/step
Making all the layers of the model non-trainable
In [43]:
tl.mark_layers_non_trainable()
Adding a final layer for classification of 15 classes
In [40]:
tl.add_final_layer()
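Because the new head starts with Flatten, its size depends on the spatial size of ResNet50's last feature map. Assuming the standard Keras ResNet50 layout (a 7x7 stride-2 convolution after 3-pixel zero padding, a 3x3 stride-2 max pool after 1-pixel padding, then stride-2 1x1 convolutions at the start of stages 3, 4 and 5), a 75x75 input shrinks to a 3x3x2048 map. The arithmetic below, with a hypothetical helper `conv_out`, estimates how many parameters the head adds.

```python
def conv_out(size, kernel, stride, pad=0):
    """Output length of a 'valid' convolution/pooling after symmetric zero padding."""
    return (size + 2 * pad - kernel) // stride + 1


s = 75
s = conv_out(s, kernel=7, stride=2, pad=3)   # conv1: 75 -> 38
s = conv_out(s, kernel=3, stride=2, pad=1)   # pool1: 38 -> 19
for _ in range(3):                           # stages 3, 4, 5 (stage 2 keeps stride 1)
    s = conv_out(s, kernel=1, stride=2)      # 19 -> 10 -> 5 -> 3
flat = s * s * 2048                          # features produced by Flatten()

dense_hidden = flat * 1000 + 1000            # Dense(1000): weights + biases
dense_out = 1000 * 15 + 15                   # Dense(15): weights + biases
print(s, flat, dense_hidden + dense_out)     # → 3 18432 18448015
```

Under these assumptions the head alone holds about 18.4 million trainable parameters, nearly all in the first Dense layer. Replacing Flatten with GlobalAveragePooling2D (an alternative design, not what this notebook uses) would feed 2,048 features instead and shrink that layer to 2,049,000 parameters.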
Compiling the model
In [41]:
tl.compile_model()
Training the model
In [44]:
tl.train_model()
Epoch 1/10
516/516 ━━━━━━━━━━━━━━━━━━━━ 323s 626ms/step - accuracy: 0.9306 - loss: 0.2121 - val_accuracy: 0.8800 - val_loss: 0.4038
Epoch 2/10
516/516 ━━━━━━━━━━━━━━━━━━━━ 317s 614ms/step - accuracy: 0.9549 - loss: 0.1300 - val_accuracy: 0.9018 - val_loss: 0.3809
Epoch 3/10
516/516 ━━━━━━━━━━━━━━━━━━━━ 330s 630ms/step - accuracy: 0.9622 - loss: 0.1127 - val_accuracy: 0.9193 - val_loss: 0.3064
Epoch 4/10
516/516 ━━━━━━━━━━━━━━━━━━━━ 383s 632ms/step - accuracy: 0.9793 - loss: 0.0639 - val_accuracy: 0.8683 - val_loss: 0.6121
Epoch 5/10
516/516 ━━━━━━━━━━━━━━━━━━━━ 329s 637ms/step - accuracy: 0.9698 - loss: 0.0995 - val_accuracy: 0.9033 - val_loss: 0.4597
Epoch 6/10
516/516 ━━━━━━━━━━━━━━━━━━━━ 338s 654ms/step - accuracy: 0.9766 - loss: 0.0749 - val_accuracy: 0.9004 - val_loss: 0.5113
Epoch 7/10
516/516 ━━━━━━━━━━━━━━━━━━━━ 337s 653ms/step - accuracy: 0.9688 - loss: 0.1104 - val_accuracy: 0.9106 - val_loss: 0.4735
Epoch 8/10
516/516 ━━━━━━━━━━━━━━━━━━━━ 324s 627ms/step - accuracy: 0.9825 - loss: 0.0582 - val_accuracy: 0.9062 - val_loss: 0.5447
Epoch 9/10
516/516 ━━━━━━━━━━━━━━━━━━━━ 321s 620ms/step - accuracy: 0.9794 - loss: 0.0754 - val_accuracy: 0.9266 - val_loss: 0.4916
Epoch 10/10
516/516 ━━━━━━━━━━━━━━━━━━━━ 319s 617ms/step - accuracy: 0.9849 - loss: 0.0571 - val_accuracy: 0.9130 - val_loss: 0.6169
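In [61]:
# Keras 3 treats .h5 as a legacy format; the native .keras format is recommended.
os.makedirs("models", exist_ok=True)
tl.model.save("models/first_model.keras")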
In [46]:
CLASS_NAMES = list(train.class_indices.keys())
CLASS_NAMES
Out[46]:
['Pepper__bell___Bacterial_spot', 'Pepper__bell___healthy', 'Potato___Early_blight', 'Potato___Late_blight', 'Potato___healthy', 'Tomato_Bacterial_spot', 'Tomato_Early_blight', 'Tomato_Late_blight', 'Tomato_Leaf_Mold', 'Tomato_Septoria_leaf_spot', 'Tomato_Spider_mites_Two_spotted_spider_mite', 'Tomato__Target_Spot', 'Tomato__Tomato_YellowLeaf__Curl_Virus', 'Tomato__Tomato_mosaic_virus', 'Tomato_healthy']
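When no explicit `classes` list is given, flow_from_directory assigns label indices by sorting the class sub-folder names, which is why CLASS_NAMES comes back in this order and can be passed straight to classification_report as target_names. The sketch shows that ordinary Python string sorting reproduces the printed order regardless of the order os.listdir() returns: "_" (0x5F) sorts after uppercase letters and before lowercase ones, so Tomato__Target_Spot lands after Tomato_Spider_mites_Two_spotted_spider_mite and before Tomato_healthy. The shuffle only simulates an arbitrary directory listing.

```python
import random

# The order printed by train.class_indices above.
CLASS_NAMES = [
    'Pepper__bell___Bacterial_spot', 'Pepper__bell___healthy',
    'Potato___Early_blight', 'Potato___Late_blight', 'Potato___healthy',
    'Tomato_Bacterial_spot', 'Tomato_Early_blight', 'Tomato_Late_blight',
    'Tomato_Leaf_Mold', 'Tomato_Septoria_leaf_spot',
    'Tomato_Spider_mites_Two_spotted_spider_mite', 'Tomato__Target_Spot',
    'Tomato__Tomato_YellowLeaf__Curl_Virus', 'Tomato__Tomato_mosaic_virus',
    'Tomato_healthy',
]

shuffled = CLASS_NAMES[:]
random.Random(0).shuffle(shuffled)           # simulate an arbitrary os.listdir() order
class_indices = {name: i for i, name in enumerate(sorted(shuffled))}

print(sorted(shuffled) == CLASS_NAMES)       # → True
print(class_indices['Tomato_healthy'])       # → 14
```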
In [47]:
from sklearn.metrics import accuracy_score, classification_report
In [48]:
predictions = np.argmax(tl.model.predict(test), axis=1)
65/65 ━━━━━━━━━━━━━━━━━━━━ 29s 413ms/step
In [49]:
acc = accuracy_score(test.labels, predictions)
cm = tf.math.confusion_matrix(test.labels, predictions)
clr = classification_report(test.labels, predictions, target_names=CLASS_NAMES)
print("Test Accuracy: {:.3f}%".format(acc * 100))
plt.figure(figsize=(8, 8))
sns.heatmap(cm, annot=True, fmt='g', vmin=0, cmap='Blues', cbar=False)
plt.xticks(ticks= np.arange(15) + 0.5, labels=CLASS_NAMES, rotation=90)
plt.yticks(ticks= np.arange(15) + 0.5, labels=CLASS_NAMES, rotation=0)
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.title("Confusion Matrix")
plt.show()
Test Accuracy: 92.582%
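tf.math.confusion_matrix puts the actual classes on the rows and the predicted classes on the columns, matching the axis labels in the plot, so the diagonal holds the correct predictions and accuracy is the trace divided by the total. The toy NumPy sketch below uses made-up labels to show that bookkeeping (`confusion_matrix` here is a hypothetical helper, not the TensorFlow or scikit-learn function). The last line checks that the reported 92.582% on 2,076 test images corresponds to 1,922 correct predictions.

```python
import numpy as np


def confusion_matrix(y_true, y_pred, num_classes):
    """Rows = actual class, columns = predicted class."""
    cm = np.zeros((num_classes, num_classes), dtype=int)
    np.add.at(cm, (y_true, y_pred), 1)   # count every (actual, predicted) pair
    return cm


# Made-up labels for 3 classes, only to illustrate the layout.
y_true = np.array([0, 0, 1, 1, 2, 2, 2])
y_pred = np.array([0, 1, 1, 1, 2, 0, 2])
cm = confusion_matrix(y_true, y_pred, 3)
print(cm)
print(np.trace(cm) / cm.sum())           # accuracy = correct / total

print(round(1922 / 2076 * 100, 3))       # → 92.582
```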
In [50]:
print(clr)
                                             precision    recall  f1-score   support

Pepper__bell___Bacterial_spot                     0.99      0.77      0.87       101
Pepper__bell___healthy                            0.93      0.98      0.95       149
Potato___Early_blight                             0.99      0.90      0.94       100
Potato___Late_blight                              0.84      0.92      0.88       100
Potato___healthy                                  1.00      0.56      0.72        16
Tomato_Bacterial_spot                             0.93      0.97      0.95       214
Tomato_Early_blight                               0.85      0.80      0.82       100
Tomato_Late_blight                                0.92      0.87      0.90       192
Tomato_Leaf_Mold                                  0.96      0.92      0.94        96
Tomato_Septoria_leaf_spot                         0.91      0.97      0.94       178
Tomato_Spider_mites_Two_spotted_spider_mite       0.87      0.95      0.91       169
Tomato__Target_Spot                               0.85      0.94      0.89       141
Tomato__Tomato_YellowLeaf__Curl_Virus             0.97      0.97      0.97       322
Tomato__Tomato_mosaic_virus                       1.00      0.87      0.93        38
Tomato_healthy                                    0.99      0.97      0.98       160

accuracy                                                              0.93      2076
macro avg                                         0.93      0.89      0.91      2076
weighted avg                                      0.93      0.93      0.93      2076
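The two summary rows average the per-class scores differently: "macro avg" gives every class equal weight, while "weighted avg" weights each class by its support. Since support-weighted recall is simply total correct predictions over total images, the weighted recall equals overall accuracy up to the rounding of the printed values. Recomputing from the rounded report values shows why macro recall (0.89) trails weighted recall (0.93): Potato___healthy has a recall of only 0.56 but just 16 test images.

```python
# Per-class (recall, support) copied from the classification report above (rounded values).
report = {
    'Pepper__bell___Bacterial_spot': (0.77, 101),
    'Pepper__bell___healthy': (0.98, 149),
    'Potato___Early_blight': (0.90, 100),
    'Potato___Late_blight': (0.92, 100),
    'Potato___healthy': (0.56, 16),
    'Tomato_Bacterial_spot': (0.97, 214),
    'Tomato_Early_blight': (0.80, 100),
    'Tomato_Late_blight': (0.87, 192),
    'Tomato_Leaf_Mold': (0.92, 96),
    'Tomato_Septoria_leaf_spot': (0.97, 178),
    'Tomato_Spider_mites_Two_spotted_spider_mite': (0.95, 169),
    'Tomato__Target_Spot': (0.94, 141),
    'Tomato__Tomato_YellowLeaf__Curl_Virus': (0.97, 322),
    'Tomato__Tomato_mosaic_virus': (0.87, 38),
    'Tomato_healthy': (0.97, 160),
}

recalls = [rec for rec, _ in report.values()]
supports = [sup for _, sup in report.values()]
total = sum(supports)

macro_recall = sum(recalls) / len(recalls)                                  # every class counts once
weighted_recall = sum(rec * sup for rec, sup in zip(recalls, supports)) / total  # weighted by support
print(total, round(macro_recall, 2), round(weighted_recall, 2))            # → 2076 0.89 0.93
```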
Plotting the Learning Curves
In [51]:
tl.plot_history()
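- At the last epoch, the model reaches 98.49% accuracy on the training set and 91.30% on the validation set.
- This 7-point gap, together with a validation loss that climbs from 0.3064 at epoch 3 to 0.6169 at epoch 10 while the training loss continues to trend down, indicates that the model is overfitting. This can be addressed by fine-tuning: re-training some of the pretrained layers and, optionally, adding regularization.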
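The overfitting claim can be checked directly against the per-epoch values printed during training. The sketch copies the logged accuracies and validation losses (the numbers come from the training log; the variable names are illustrative) and reports the epoch with the lowest validation loss and the final train/validation accuracy gap.

```python
# Values copied from the transfer-learning training log (epochs 1 to 10).
history = {
    'accuracy':     [0.9306, 0.9549, 0.9622, 0.9793, 0.9698, 0.9766, 0.9688, 0.9825, 0.9794, 0.9849],
    'val_accuracy': [0.8800, 0.9018, 0.9193, 0.8683, 0.9033, 0.9004, 0.9106, 0.9062, 0.9266, 0.9130],
    'val_loss':     [0.4038, 0.3809, 0.3064, 0.6121, 0.4597, 0.5113, 0.4735, 0.5447, 0.4916, 0.6169],
}

val_loss = history['val_loss']
best_epoch = val_loss.index(min(val_loss)) + 1                  # 1-based epoch number
gap = history['accuracy'][-1] - history['val_accuracy'][-1]     # final train - val accuracy
print(best_epoch, round(gap * 100, 2))                          # → 3 7.19
```

Validation loss bottoms out at epoch 3 and roughly doubles by epoch 10, the usual signature of a model that keeps fitting the training set after it has stopped generalising better.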
Fine-tuning
Fine-tuning also starts from a pretrained model, but instead of keeping every pretrained layer frozen, some of the later layers are made trainable so that they can adapt to patterns specific to the current dataset. Regularization, for example dropout layers, can also be added.
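In the FineTuning class the boundary between frozen and trainable layers is a single index into model.layers. The plain-Python sketch mimics that slicing on a stand-in list of 175 objects (the layer count of Keras' ResNet50 with include_top=False; `SimpleNamespace` stands in for real Keras layers) to show that fine_tune_from = 100 leaves exactly the last 75 layers trainable.

```python
from types import SimpleNamespace

# Stand-ins for the 175 layers of ResNet50(include_top=False); only .trainable matters here.
layers = [SimpleNamespace(name=f"layer_{i}", trainable=True) for i in range(175)]
fine_tune_from = 100

for layer in layers[:fine_tune_from]:   # early, generic feature extractors stay frozen
    layer.trainable = False
for layer in layers[fine_tune_from:]:   # later, more task-specific layers are re-trained
    layer.trainable = True

n_trainable = sum(layer.trainable for layer in layers)
print(n_trainable, layers[fine_tune_from - 1].trainable, layers[fine_tune_from].trainable)
# → 75 False True
```

With a real model, the same count can be read with `sum(l.trainable for l in model.layers)` after calling fine_tune().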
In [52]:
class FineTuning:
    def __init__(self, train, val) -> None:
        self.train = train
        self.val = val
        self.model = None
        self.history = None
        # Index of the first ResNet50 layer to unfreeze (layers 100 onward are trained).
        self.fine_tune_from = 100

    def load_model(self):
        # ResNet50 convolutional base with ImageNet weights, without its 1000-class head.
        self.model = ResNet50(weights = 'imagenet',
                              include_top = False,
                              input_shape = (75,75,3))

    def fine_tune(self):
        # Freeze the early, generic layers and unfreeze the later ones.
        for layer in self.model.layers[:self.fine_tune_from]:
            layer.trainable = False
        for layer in self.model.layers[self.fine_tune_from:]:
            layer.trainable = True

    def add_final_layer(self):
        # Same head as in TransferLearning.
        self.x = Flatten()(self.model.output)
        self.x = Dense(1000, activation='relu')(self.x)
        self.predictions = Dense(15, activation = 'softmax')(self.x)

    def compile_model(self):
        self.model = Model(inputs = self.model.input, outputs = self.predictions)
        self.model.compile(optimizer='adam', loss="categorical_crossentropy", metrics=['accuracy'])

    def train_model(self):
        # Stop once val_loss has not improved for 3 epochs and keep the best weights seen.
        self.history = self.model.fit(self.train,
                                      epochs=5,
                                      validation_data=self.val,
                                      callbacks=[
                                          tf.keras.callbacks.EarlyStopping(
                                              monitor='val_loss',
                                              patience=3,
                                              restore_best_weights=True
                                          )
                                      ])

    def plot_history(self):
        fig, axs = plt.subplots(2, 1, figsize=(15,15))
        axs[0].plot(self.history.history['loss'])
        axs[0].plot(self.history.history['val_loss'])
        axs[0].title.set_text('Training Loss vs Validation Loss')
        axs[0].set_xlabel('Epochs')
        axs[0].set_ylabel('Loss')
        axs[0].legend(['Train','Val'])
        axs[1].plot(self.history.history['accuracy'])
        axs[1].plot(self.history.history['val_accuracy'])
        axs[1].title.set_text('Training Accuracy vs Validation Accuracy')
        axs[1].set_xlabel('Epochs')
        axs[1].set_ylabel('Accuracy')
        axs[1].legend(['Train', 'Val'])
        plt.show()
Fine Tuning the ResNet50 model
In [53]:
ft = FineTuning(train,val)
Loading the ResNet50 model from Keras Applications
In [54]:
ft.load_model()
Making the last 75 layers of the ResNet50 model trainable
In [55]:
ft.fine_tune()
Adding a final layer for classification of 15 classes
In [56]:
ft.add_final_layer()
Compiling the model
In [57]:
ft.compile_model()
Training the model for 5 epochs
In [58]:
ft.train_model()
Epoch 1/5
516/516 ━━━━━━━━━━━━━━━━━━━━ 837s 2s/step - accuracy: 0.7715 - loss: 1.3325 - val_accuracy: 0.9397 - val_loss: 0.1707
Epoch 2/5
516/516 ━━━━━━━━━━━━━━━━━━━━ 804s 2s/step - accuracy: 0.9495 - loss: 0.1735 - val_accuracy: 0.9485 - val_loss: 0.1661
Epoch 3/5
516/516 ━━━━━━━━━━━━━━━━━━━━ 862s 2s/step - accuracy: 0.9762 - loss: 0.1016 - val_accuracy: 0.9475 - val_loss: 0.1831
Epoch 4/5
516/516 ━━━━━━━━━━━━━━━━━━━━ 821s 2s/step - accuracy: 0.9823 - loss: 0.0631 - val_accuracy: 0.9534 - val_loss: 0.2221
Epoch 5/5
516/516 ━━━━━━━━━━━━━━━━━━━━ 810s 2s/step - accuracy: 0.9842 - loss: 0.0567 - val_accuracy: 0.9543 - val_loss: 0.1700
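The EarlyStopping callback monitors val_loss with patience=3. The simplified re-implementation below is a sketch of the documented rule with min_delta = 0, not Keras' actual source. Replaying the five logged validation losses, the best value comes at epoch 2 and epochs 3, 4 and 5 fail to improve on it, so the patience runs out at the end of epoch 5, which is also the last scheduled epoch. With restore_best_weights=True this suggests the weights evaluated afterwards are those from epoch 2, although the log itself does not show the restore step.

```python
def early_stopping(val_losses, patience=3, min_delta=0.0):
    """Simplified sketch of patience-based early stopping on a metric to minimise.

    Returns (stop_epoch, best_epoch) as 1-based epoch numbers; stop_epoch is
    None if the patience is never exhausted.
    """
    best, best_epoch, wait = float("inf"), None, 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best - min_delta:      # improvement: remember it and reset the counter
            best, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1                    # one more epoch without improvement
            if wait >= patience:
                return epoch, best_epoch
    return None, best_epoch


# val_loss values copied from the fine-tuning log.
print(early_stopping([0.1707, 0.1661, 0.1831, 0.2221, 0.1700]))  # → (5, 2)
```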
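In [62]:
# Keras 3 treats .h5 as a legacy format; the native .keras format is recommended.
os.makedirs("models", exist_ok=True)
ft.model.save("models/second_model.keras")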
In [60]:
predictions = np.argmax(ft.model.predict(test), axis=1)
65/65 ━━━━━━━━━━━━━━━━━━━━ 29s 414ms/step
In [63]:
acc = accuracy_score(test.labels, predictions)
cm = tf.math.confusion_matrix(test.labels, predictions)
clr = classification_report(test.labels, predictions, target_names=CLASS_NAMES)
print("Test Accuracy: {:.3f}%".format(acc * 100))
plt.figure(figsize=(8, 8))
sns.heatmap(cm, annot=True, fmt='g', vmin=0, cmap='Blues', cbar=False)
plt.xticks(ticks= np.arange(15) + 0.5, labels=CLASS_NAMES, rotation=90)
plt.yticks(ticks= np.arange(15) + 0.5, labels=CLASS_NAMES, rotation=0)
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.title("Confusion Matrix")
plt.show()
Test Accuracy: 95.376%
In [64]:
print(clr)
                                             precision    recall  f1-score   support

Pepper__bell___Bacterial_spot                     0.95      0.99      0.97       101
Pepper__bell___healthy                            0.99      0.99      0.99       149
Potato___Early_blight                             0.97      0.98      0.98       100
Potato___Late_blight                              0.99      0.76      0.86       100
Potato___healthy                                  0.79      0.94      0.86        16
Tomato_Bacterial_spot                             0.97      0.96      0.97       214
Tomato_Early_blight                               0.91      0.92      0.92       100
Tomato_Late_blight                                0.88      0.95      0.92       192
Tomato_Leaf_Mold                                  0.98      0.96      0.97        96
Tomato_Septoria_leaf_spot                         0.98      0.96      0.97       178
Tomato_Spider_mites_Two_spotted_spider_mite       0.92      0.96      0.94       169
Tomato__Target_Spot                               0.97      0.87      0.92       141
Tomato__Tomato_YellowLeaf__Curl_Virus             0.97      0.99      0.98       322
Tomato__Tomato_mosaic_virus                       0.90      1.00      0.95        38
Tomato_healthy                                    0.97      0.99      0.98       160

accuracy                                                              0.95      2076
macro avg                                         0.94      0.95      0.94      2076
weighted avg                                      0.96      0.95      0.95      2076
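Comparing the two reports class by class shows that fine-tuning does not help every class uniformly. The sketch copies the recall column of both reports (rounded values, so the differences are approximate) and lists the largest gain and the classes whose recall dropped.

```python
# Per-class recall copied from the two reports: (transfer learning, fine-tuned).
recall = {
    'Pepper__bell___Bacterial_spot':               (0.77, 0.99),
    'Pepper__bell___healthy':                      (0.98, 0.99),
    'Potato___Early_blight':                       (0.90, 0.98),
    'Potato___Late_blight':                        (0.92, 0.76),
    'Potato___healthy':                            (0.56, 0.94),
    'Tomato_Bacterial_spot':                       (0.97, 0.96),
    'Tomato_Early_blight':                         (0.80, 0.92),
    'Tomato_Late_blight':                          (0.87, 0.95),
    'Tomato_Leaf_Mold':                            (0.92, 0.96),
    'Tomato_Septoria_leaf_spot':                   (0.97, 0.96),
    'Tomato_Spider_mites_Two_spotted_spider_mite': (0.95, 0.96),
    'Tomato__Target_Spot':                         (0.94, 0.87),
    'Tomato__Tomato_YellowLeaf__Curl_Virus':       (0.97, 0.99),
    'Tomato__Tomato_mosaic_virus':                 (0.87, 1.00),
    'Tomato_healthy':                              (0.97, 0.99),
}

# Change in recall after fine-tuning, rounded to the report's 2 decimal places.
change = {name: round(ft - tl, 2) for name, (tl, ft) in recall.items()}
worse = sorted((name for name, d in change.items() if d < 0), key=change.get)

print(max(change, key=change.get), change['Potato___healthy'])  # → Potato___healthy 0.38
print(worse)
```

Most classes improve, led by Potato___healthy (0.56 to 0.94, on only 16 test images), while Potato___Late_blight drops from 0.92 to 0.76 and Tomato__Target_Spot from 0.94 to 0.87.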
Evaluation of the fine-tuned model
In [65]:
ft.model.evaluate(test)
65/65 ━━━━━━━━━━━━━━━━━━━━ 25s 389ms/step - accuracy: 0.9519 - loss: 0.1647
Out[65]:
[0.1544826626777649, 0.9537572264671326]
Plotting the learning curves of the fine-tuning process
In [66]:
ft.plot_history()
Conclusion
- Transfer learning reuses a model pretrained on a large dataset (here, ResNet50 trained on ImageNet) and applies the knowledge captured in its weights to a new task.
- With the pretrained base frozen, the transfer-learning model reached 98.49% training and 91.30% validation accuracy at the last epoch (92.58% on the test split), while its validation loss rose from 0.3064 at epoch 3 to 0.6169 at epoch 10. This overfitting motivates fine-tuning the model.
- The model is then fine-tuned by unfreezing the last 75 of ResNet50's 175 layers (fine_tune_from = 100) so that they can adapt to patterns in this dataset, which narrows the gap between training and validation accuracy.
- At the last logged epoch the fine-tuned model reports 98.42% training and 95.43% validation accuracy, and it reaches 95.38% accuracy on the test split.
- Finally, fine-tuning carries a noticeably higher cost in time and hardware: each fine-tuning epoch took roughly 800 to 860 s, compared with about 320 to 380 s per epoch when only the new head was trained.
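The extra cost of fine-tuning can be quantified from the per-epoch times in the two training logs: with 75 ResNet50 layers unfrozen, each step has to compute gradients for far more of the network than when only the new head is trained. The figures below are copied from the logs and are specific to the hardware this notebook ran on.

```python
# Wall-clock seconds per epoch, copied from the two training logs.
transfer_s = [323, 317, 330, 383, 329, 338, 337, 324, 321, 319]
finetune_s = [837, 804, 862, 821, 810]


def mean(xs):
    return sum(xs) / len(xs)


print(sum(transfer_s), sum(finetune_s))                        # → 3321 4134
print(round(mean(transfer_s), 1), round(mean(finetune_s), 1))  # → 332.1 826.8
print(round(mean(finetune_s) / mean(transfer_s), 2))           # → 2.49
```

Each fine-tuning epoch therefore cost about 2.5 times as much as a transfer-learning epoch, so the 5 fine-tuning epochs (about 69 minutes) took longer than all 10 transfer-learning epochs (about 55 minutes).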